package byas;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.lionsoul.jcseg.ASegment;
import org.lionsoul.jcseg.core.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import static org.apache.commons.math3.util.FastMath.log;

/**
 * Author: LDL
 * Description: Naive Bayes classifier over bag-of-words count vectors.
 * Created: 2015/6/8 10:12
 *
 * <p>All helpers that take a {@link RealMatrix} operate on row 0 only unless stated
 * otherwise; the class represents vectors as 1 x n matrices.
 */
public class Bayes {
    // Jcseg segmentation task configuration,
    // i.e. the configuration initialised from the jcseg.properties file.
    public static JcsegTaskConfig config = new JcsegTaskConfig();
    public static ADictionary dic = DictionaryFactory
            .createDefaultDictionary(config);

    /**
     * Builds the training corpus.
     *
     * <p>Currently this only instantiates a Jcseg segmenter in simple mode (the
     * segmenter is not consumed yet) and returns two empty lists.
     *
     * @return a two-element array: index 0 is an {@code ArrayList<ArrayList<String>>}
     *         of tokenised documents, index 1 is an {@code ArrayList<Integer>} of labels
     * @throws IOException declared for callers; nothing in the current body throws it
     */
    public static Object[] createdata() throws IOException {
        ArrayList<ArrayList<String>> retList = Lists.newArrayList();
        ArrayList<Integer> labels = Lists.newArrayList();

        ASegment seg = null;
        try {
            seg = (ASegment) SegmentFactory
                    .createJcseg(JcsegTaskConfig.SIMPLE_MODE,
                            new Object[]{config, dic});
        } catch (JcsegException e) {
            // Best effort: the segmenter is not used below, so a failure here is
            // reported but does not abort the call.
            e.printStackTrace();
        }

        // TODO: tokenise real documents with the segmenter, e.g.
        //   seg.reset(new StringReader(text));
        //   IWord word;
        //   while ((word = seg.next()) != null) { tokens.add(word.getValue()); }
        // and append each token list to retList together with its label.
        //
        // Toy corpus previously used for manual testing (label 1 = abusive):
        //   ["my", "dog", "has", "flea", "problems", "help", "please"]           -> 0
        //   ["maybe", "not", "take", "him", "to", "dog", "park", "stupid"]       -> 1
        //   ["my", "dalmation", "is", "so", "cute", "I", "love", "him"]          -> 0
        //   ["stop", "posting", "stupid", "worthless", "garbage"]                -> 1
        //   ["mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him"]    -> 0
        //   ["quit", "buying", "worthless", "dog", "food", "stupid"]             -> 1
        return new Object[]{retList, labels};
    }


    /**
     * Collects the distinct words occurring in the given documents.
     *
     * @param lists tokenised documents
     * @return a list holding every distinct word once; the order is that of a
     *         {@link HashSet} iteration and therefore unspecified
     */
    public static ArrayList<String> createVocabSet(ArrayList<ArrayList<String>> lists) {
        HashSet<String> retSet = Sets.newHashSet();
        for (ArrayList<String> list : lists) {
            retSet.addAll(list);
        }
        return Lists.newArrayList(retSet);
    }

    /**
     * Converts a tokenised document into a word-count vector over the vocabulary.
     *
     * @param set       vocabulary; position {@code i} of the result counts {@code set.get(i)}
     * @param inputData tokens of the document; tokens not in the vocabulary are ignored
     * @return an array of length {@code set.size()} with the occurrence count of each word
     */
    public static double[] bagOfWords2VecMN(ArrayList<String> set, ArrayList<String> inputData) {
        // Index the vocabulary once instead of scanning the list twice per token.
        // Only the first position of a word is kept, which matches List.indexOf.
        Map<String, Integer> indexByWord = new HashMap<String, Integer>(set.size() * 2);
        for (int i = 0; i < set.size(); i++) {
            String word = set.get(i);
            if (!indexByWord.containsKey(word)) {
                indexByWord.put(word, i);
            }
        }

        double[] returnVec = new double[set.size()];
        for (String token : inputData) {
            Integer position = indexByWord.get(token);
            if (position != null) {
                returnVec[position]++;
            }
        }
        return returnVec;
    }

    /**
     * Trains the classifier.
     *
     * <p>Word counts start at 1 and the per-class denominators start at 2, so no
     * word ends up with probability zero.
     *
     * @param realMatrix one row per training document, one column per vocabulary word
     * @param labels     one label per row; the value 1 selects class 1, any other
     *                   value selects class 0
     * @return a three-element array: index 0 and 1 are 1 x numWords {@link RealMatrix}
     *         instances with the log word probabilities of class 0 and class 1,
     *         index 2 is a {@code Float} with the fraction of documents labelled 1
     * @throws IllegalArgumentException if the number of labels differs from the number of rows
     */
    public static Object[] trainNB(RealMatrix realMatrix, ArrayList<Integer> labels) {
        int numTrainDocs = realMatrix.getRowDimension();
        if (labels.size() != numTrainDocs) {
            throw new IllegalArgumentException(
                    "Expected one label per document row: rows=" + numTrainDocs
                            + ", labels=" + labels.size());
        }
        int numWords = realMatrix.getColumnDimension();

        // Per-class word counts, initialised to 1.
        RealMatrix p0Matrix = oneNums(MatrixUtils.createRealMatrix(1, numWords));
        RealMatrix p1Matrix = oneNums(MatrixUtils.createRealMatrix(1, numWords));
        double p0Denom = 2.0;
        double p1Denom = 2.0;
        int class1Docs = 0;

        // Add each document's word counts to its class and grow that class's total.
        for (int i = 0; i < numTrainDocs; i++) {
            RealMatrix docRow = realMatrix.getRowMatrix(i);
            double docTotal = sumFirstRow(docRow);
            if (labels.get(i) == 1) {
                class1Docs++;
                p1Matrix = p1Matrix.add(docRow);
                p1Denom += docTotal;
            } else {
                p0Matrix = p0Matrix.add(docRow);
                p0Denom += docTotal;
            }
        }

        float pAbusive = (float) class1Docs / numTrainDocs;

        // Log word probabilities per class.
        RealMatrix p0 = logMatrix(p0Matrix.scalarMultiply(1.0 / p0Denom));
        RealMatrix p1 = logMatrix(p1Matrix.scalarMultiply(1.0 / p1Denom));
        return new Object[]{p0, p1, pAbusive};
    }


    /**
     * Sets every entry of the matrix to 1, in place.
     *
     * @param realMatrix matrix to fill
     * @return the same instance that was passed in
     */
    public static RealMatrix oneNums(RealMatrix realMatrix) {
        for (int r = 0; r < realMatrix.getRowDimension(); r++) {
            for (int c = 0; c < realMatrix.getColumnDimension(); c++) {
                realMatrix.setEntry(r, c, 1.0);
            }
        }
        return realMatrix;
    }

    /**
     * Sums the entries of row 0 of the matrix.
     *
     * @param realMatrix matrix whose first row is summed
     * @return the sum, accumulated in double precision and narrowed to {@code float}
     */
    public static float sumMatrix(RealMatrix realMatrix) {
        return (float) sumFirstRow(realMatrix);
    }

    /** Sums the entries of row 0 in double precision. */
    private static double sumFirstRow(RealMatrix matrix) {
        double total = 0.0;
        for (double value : matrix.getRow(0)) {
            total += value;
        }
        return total;
    }

    /**
     * Replaces every entry of row 0 with its natural logarithm, in place.
     * Other rows are left untouched.
     *
     * @param realMatrix matrix to transform
     * @return the same instance that was passed in
     */
    public static RealMatrix logMatrix(RealMatrix realMatrix) {
        double[] values = realMatrix.getRow(0);
        double[] logged = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            logged[i] = log(values[i]);
        }
        realMatrix.setRow(0, logged);
        return realMatrix;
    }

    /**
     * Multiplies row 0 of both matrices element by element.
     *
     * @param m1 first operand; also determines the dimensions of the result
     * @param m2 second operand
     * @return a new matrix with the dimensions of {@code m1} whose row 0 holds the
     *         element-wise product; any further rows are zero
     */
    public static RealMatrix multiply(RealMatrix m1, RealMatrix m2) {
        RealVector product = m1.getRowVector(0).ebeMultiply(m2.getRowVector(0));
        RealMatrix result = MatrixUtils.createRealMatrix(
                m1.getRowDimension(), m1.getColumnDimension());
        result.setRowVector(0, product);
        return result;
    }


    /**
     * Classifies a document assuming both classes are equally likely a priori.
     *
     * @param realMatrix 1 x numWords word-count vector of the document
     * @param p0M        log word probabilities of class 0 (a {@link RealMatrix})
     * @param p1M        log word probabilities of class 1 (a {@link RealMatrix})
     * @return 0 if class 0 scores strictly higher, otherwise 1
     */
    public static int classify(RealMatrix realMatrix, Object p0M, Object p1M) {
        return classify(realMatrix, p0M, p1M, 0.5);
    }

    /**
     * Classifies a document using an explicit prior for class 1, e.g. the third
     * element returned by {@link #trainNB(RealMatrix, ArrayList)}.
     *
     * @param realMatrix 1 x numWords word-count vector of the document
     * @param p0M        log word probabilities of class 0 (a {@link RealMatrix})
     * @param p1M        log word probabilities of class 1 (a {@link RealMatrix})
     * @param pClass1    prior probability of class 1, in the range [0, 1]
     * @return 0 if class 0 scores strictly higher, otherwise 1
     * @throws IllegalArgumentException if {@code pClass1} is NaN or outside [0, 1]
     */
    public static int classify(RealMatrix realMatrix, Object p0M, Object p1M, double pClass1) {
        if (!(pClass1 >= 0.0 && pClass1 <= 1.0)) {
            throw new IllegalArgumentException("pClass1 must be within [0, 1]: " + pClass1);
        }
        // Scores are kept in double precision so near-ties are not decided by rounding.
        double score0 = sumFirstRow(multiply(realMatrix, (RealMatrix) p0M)) + log(1.0 - pClass1);
        double score1 = sumFirstRow(multiply(realMatrix, (RealMatrix) p1M)) + log(pClass1);
        if (score0 > score1) {
            return 0;
        }
        return 1;
    }



    /**
     * Entry point; currently only invokes {@link #createdata()}.
     *
     * @param args unused
     * @throws IOException propagated from {@link #createdata()}
     */
    public static void main(String[] args) throws IOException {
        // Example end-to-end pipeline. It needs createdata() to return a non-empty
        // corpus, so it stays disabled for now:
        //
        //   Object[] retData = createdata();
        //   ArrayList<ArrayList<String>> lists = (ArrayList<ArrayList<String>>) retData[0];
        //   ArrayList<String> set = createVocabSet(lists);
        //   RealMatrix m = MatrixUtils.createRealMatrix(lists.size(), set.size());
        //   for (int i = 0; i < lists.size(); i++) {
        //       m.setRow(i, bagOfWords2VecMN(set, lists.get(i)));
        //   }
        //
        //   Object[] retP = trainNB(m, (ArrayList<Integer>) retData[1]);
        //
        //   ArrayList<String> test = Lists.newArrayList("love");
        //   RealMatrix m1 = MatrixUtils.createRealMatrix(1, set.size());
        //   m1.setRow(0, bagOfWords2VecMN(set, test));
        //   System.out.println(classify(m1, retP[0], retP[1]));
        createdata();
    }
}
